#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations

import copy
import os
import warnings

import requests

from rich.prompt import Confirm

from autoapi.common.json_handler import read_json_file
from autoapi.common.yaml_handler import write_yaml
from autoapi.core.get_conf import autoapi_config
from autoapi.core.path_conf import autoapi_path
from autoapi.utils.data_manage.base_format import format_value
from autoapi.utils.file_control import get_file_property
from autoapi.utils.rich_console import console
from autoapi.utils.time_control import get_current_timestamp


class SwaggerParser:
    def __init__(self, version: int | None = None, data: dict | None = None):
        """
        初始化参数

        :param version:
        :param data:
        """
        self.version = version
        self.data = data

    def import_openapi_to_yaml(
        self, openapi_source: str, project: str | None = None
    ) -> None:
        """
        导入 openapi 数据到 yaml

        :param openapi_source:
        :param project:
        :return:
        """
        global warn_operationId
        self.get_swagger_data(openapi_source)
        try:
            # 测试用例数据配置项
            case_config = {}
            case_config.update(
                {
                    "allure": {
                        "epic": self.data["info"]["title"],
                        "feature": (
                            self.data["tags"][0]["name"]
                            if self.data.get("tags") is not None
                            else "未知"
                        ),
                        "story": self.data["info"]["description"],
                    },
                    "request": {"env": autoapi_config.ENV_DEFAULT},
                    "module": self.data["info"]["title"],
                }
            )

            # 测试用例数据
            case = {}
            # 子目录
            tag = ""
            # 子目录测试用例数据
            tag_case = {}
            # 根目录测试用例数据
            root_case = {}

            is_tag = Confirm.ask(
                "❓ 是否按 openapi 标签划分数据存放目录?", default=True
            )
            for url, values in self.data["paths"].items():
                for method, values_map in values.items():
                    params = self.get_swagger_params(values_map)
                    headers = self.get_swagger_headers(values_map)
                    if self.version == 2:
                        data = (
                            self.get_swagger_request_data(values_map)
                            if method.lower() != "get"
                            else None
                        )
                        files = (
                            self.get_swagger_request_files(values_map)
                            if method.lower() != "get"
                            else None
                        )
                    else:
                        try:
                            schema = (
                                values_map["requestBody"]["content"][
                                    headers["Content-Type"]
                                ]["schema"]["$ref"].split("/")[-1]
                                if values_map.get("requestBody") is not None
                                else None
                            )
                        except KeyError:
                            schema = values_map["requestBody"]["content"][
                                headers["Content-Type"]
                            ]["schema"]
                            warn_operationId = values_map["operationId"]
                        data = (
                            self.get_swagger_request_data(schema)
                            if schema is not None
                            else None
                        )
                        files = (
                            self.get_swagger_request_files(schema)
                            if schema is not None
                            else None
                        )
                    data_type = None
                    if headers is not None:
                        data_type = "json" if "application/json" in headers else "data"
                    case = copy.deepcopy(case)
                    case.update(
                        {
                            "name": values_map.get("summary"),
                            "case_id": values_map.get("operationId"),
                            "description": values_map.get("description"),
                            "is_run": values_map.get("deprecated", True),
                            "request": {
                                "method": method,
                                "url": url,
                                "params": params,
                                "headers": headers,
                                "data_type": data_type,
                                "data": data,
                                "files": files,
                            },
                        }
                    )
                    tags = values_map.get("tags")
                    # 按 tag 划分目录
                    if is_tag:
                        if tags is not None:
                            dp_tag = copy.deepcopy(tag)
                            tag = tags[0]
                            (
                                tag_case.update({tag: [case]})
                                if dp_tag != tag
                                else tag_case[tag].append(case)
                            )
                        else:
                            root_case = copy.deepcopy(root_case)
                            (
                                root_case.update({"root": [case]})
                                if root_case.get("root") is None
                                else root_case["root"].append(case)
                            )
                    # 不按 tag 划分目录
                    else:
                        if tags:
                            dp_tag = copy.deepcopy(tag)
                            tag = tags[0]
                            (
                                root_case.update({tag: [case]})
                                if dp_tag != tag
                                else root_case[tag].append(case)
                            )
                        else:
                            root_case = copy.deepcopy(root_case)
                            (
                                root_case.update({"root": [case]})
                                if root_case.get("root") is None
                                else root_case["root"].append(case)
                            )
            if len(tag_case) > 0 or len(root_case) > 0:
                # 文件创建提醒
                file_list = []
                if len(tag_case) > 0:
                    for k, v in tag_case.items():
                        tag_filename = os.sep.join(
                            [
                                project or autoapi_config.PROJECT_NAME,
                                k,
                                (
                                    get_file_property(openapi_source)[1] + ".yaml"
                                    if not openapi_source.startswith("http")
                                    else f"openapi_{k}.yaml"
                                ),
                            ]
                        )
                        file_list.append(tag_filename)
                if len(root_case) > 0:
                    for k, v in root_case.items():
                        if k == "root":
                            root_filename = os.sep.join(
                                [
                                    project or autoapi_config.PROJECT_NAME,
                                    (
                                        get_file_property(openapi_source)[1] + ".yaml"
                                        if not openapi_source.startswith("http")
                                        else f"openapi_{get_current_timestamp()}.yaml"
                                    ),
                                ]
                            )
                        else:
                            root_filename = os.sep.join(
                                [
                                    project or autoapi_config.PROJECT_NAME,
                                    (
                                        get_file_property(openapi_source)[1] + ".yaml"
                                        if not openapi_source.startswith("http")
                                        else f"openapi_{k}.yaml"
                                    ),
                                ]
                            )
                        file_list.append(root_filename)
                console.print("⚠️ 即将创建以下数据文件:")
                for i in file_list:
                    console.print(f"\n\tdata\\test_data\\{i}")
                is_force_write = Confirm.ask(
                    "\n👁️ 请检查是否存在同名文件, 此操作将强制覆盖写入所有数据文件, 是否继续执行? (此操作不可逆)",
                    default=False,
                )
                # 强制写入
                if is_force_write:
                    console.print("⏳ 奋力导入中...")
                    # 写入项目 tag 目录
                    if len(tag_case) > 0:
                        for k, v in tag_case.items():
                            case_config["allure"]["feature"] = case_config["module"] = k
                            case_file_data = {"config": case_config, "test_steps": v}
                            write_yaml(
                                autoapi_path.case_data_dir,
                                os.sep.join(
                                    [
                                        project or autoapi_config.PROJECT_NAME,
                                        k,
                                        (
                                            get_file_property(openapi_source)[1]
                                            + ".yaml"
                                            if not openapi_source.startswith("http")
                                            else f"openapi_{k}.yaml"
                                        ),
                                    ]
                                ),
                                case_file_data,
                                mode="w",
                            )
                    # 写入项目根目录
                    if len(root_case) > 0:
                        for k, v in root_case.items():
                            if k == "root":
                                filename = (
                                    get_file_property(openapi_source)[1] + ".yaml"
                                    if not openapi_source.startswith("http")
                                    else f"openapi_{get_current_timestamp()}.yaml"
                                )
                            else:
                                filename = (
                                    get_file_property(openapi_source)[1] + ".yaml"
                                    if not openapi_source.startswith("http")
                                    else f"openapi_{k}.yaml"
                                )
                            case_file_data = {"config": case_config, "test_steps": v}
                            write_yaml(
                                autoapi_path.case_data_dir,
                                os.sep.join(
                                    [project or autoapi_config.PROJECT_NAME, filename]
                                ),
                                case_file_data,
                                mode="w",
                            )
                # 选择写入
                else:
                    console.print("⚠️ 已取消强制覆写入所有数据文件")
                    is_next = Confirm.ask(
                        "❓ 是否进行逐一选择创建数据文件?", default=True
                    )
                    if is_next:
                        # 写入项目 tag 目录
                        if len(tag_case) > 0:
                            for k, v in tag_case.items():
                                case_config["allure"]["feature"] = case_config[
                                    "module"
                                ] = k
                                case_file_data = {
                                    "config": case_config,
                                    "test_steps": v,
                                }
                                tag_filename = os.sep.join(
                                    [
                                        project or autoapi_config.PROJECT_NAME,
                                        k,
                                        (
                                            get_file_property(openapi_source)[1]
                                            + ".yaml"
                                            if not openapi_source.startswith("http")
                                            else f"openapi_{k}.yaml"
                                        ),
                                    ]
                                )
                                is_write = Confirm.ask(
                                    f"❓ 是否需要创建 {tag_filename} 数据文件?",
                                    default=True,
                                )
                                if is_write:
                                    write_yaml(
                                        autoapi_path.case_data_dir,
                                        tag_filename,
                                        case_file_data,
                                        mode="w",
                                    )
                        # 写入项目根目录
                        if len(root_case) > 0:
                            for k, v in root_case.items():
                                if k == "root":
                                    root_filename = (
                                        get_file_property(openapi_source)[1] + ".yaml"
                                        if not openapi_source.startswith("http")
                                        else f"openapi_{get_current_timestamp()}.yaml"
                                    )
                                else:
                                    root_filename = (
                                        get_file_property(openapi_source)[1] + ".yaml"
                                        if not openapi_source.startswith("http")
                                        else f"openapi_{k}.yaml"
                                    )
                                is_write = Confirm.ask(
                                    f"❓ 是否需要创建 {root_filename} 数据文件?",
                                    default=True,
                                )
                                if is_write:
                                    case_file_data = {
                                        "config": case_config,
                                        "test_steps": v,
                                    }
                                    write_yaml(
                                        autoapi_path.case_data_dir,
                                        os.sep.join(
                                            [
                                                project or autoapi_config.PROJECT_NAME,
                                                root_filename,
                                            ]
                                        ),
                                        case_file_data,
                                        mode="w",
                                    )
            console.print("✅ 导入 openapi 数据成功")
        except Exception as e:
            raise e

    def get_swagger_data(self, openapi_source: str) -> None:
        """
        请求 swagger 数据

        :param openapi_source:
        :return:
        """
        if openapi_source.startswith("http"):
            request = requests.session()
            request.trust_env = False
            data = request.get(openapi_source).json()
            openapi = data.get("openapi")
            swagger = data.get("swagger")
            if not (openapi or swagger):
                raise ValueError("请输入正确的 openapi 地址")
        else:
            data = read_json_file(openapi_source)
            openapi = data.get("openapi")
            swagger = data.get("swagger")
            if not (openapi or swagger):
                raise ValueError("获取 openapi 数据失败，请使用合法的 openapi 文件")
        if openapi is not None and int(openapi.split(".")[0]) == 3:
            self.version = 3
        elif swagger == "2.0":
            self.version = 2
        else:
            raise Exception("不受支持的 openapi 版本")
        self.data = data

    def get_swagger_params(self, value: dict) -> dict | None:
        """
        获取查询参数

        :param value:
        :return:
        """
        data = {}
        if self.version == 2:
            params = value.get("parameters")
            if params is not None:
                for i in params:
                    if i["in"] == "path":
                        data = {i["name"]: None}
            return data if len(data) > 0 else None
        else:
            params = value.get("parameters")
            if params is not None:
                for i in params:
                    data = {i["name"]: None}
            return data if len(data) > 0 else None

    def get_swagger_headers(self, value: dict) -> dict | None:
        """
        获取查询参数

        :param value:
        :return:
        """
        data = {}
        if self.version == 2:
            headers = value.get("consumes")
            if headers is not None:
                data = {"Content-Type": headers[0]}
            return data if len(data) > 0 else None
        else:
            headers = value.get("requestBody")
            if headers is not None:
                content = headers["content"]
                data = {"Content-Type": f"{[k for k in content.keys()][0]}"}
            return data if len(data) > 0 else None

    def get_swagger_schema_data(self, name: str) -> dict:
        """
        获取 schema 数据

        :param name:
        :return:
        """
        if self.version == 2:
            data = self.data.get("definitions").get(name)
        else:
            data = self.data.get("components").get("schemas").get(name)
        return data

    def get_swagger_request_data(self, value: dict | str) -> dict | None:
        """
        获取请求 data

        :param value:
        :return:
        """
        data = {}
        if self.version == 2:
            # 确保 value 是字典类型
            if isinstance(value, dict) and "parameters" in value:
                if len(value["parameters"]) > 0:
                    for i in value["parameters"]:
                        # 确保 i 是字典类型
                        if isinstance(i, dict):
                            if i.get("type") is None:
                                data[i["name"]] = format_value("object")  
                            else:
                                if i.get("type") != "file":
                                    data[i["name"]] = format_value(i.get("type", "object")) 
            return data if len(data) > 0 else None
        else:
            if not isinstance(value, dict):
                schema_data = self.get_swagger_schema_data(value)
                for k, v in schema_data["properties"].items():
                    if v.get("format") is None:
                        data[k] = format_value(v.get("type", "object"))
                    else:
                        if v.get("format") != "binary":
                            data[k] = format_value(v.get("type", "object"))
            else:
                # 请求体使用了非 schema 格式入参
                warnings.warn(f"接口(ID) {warn_operationId} 使用了非 schema 入参")
                if value.get("type") is None:
                    data[value["title"]] = format_value("object")
                else:
                    if value.get("type") != "file":
                        data[value["title"]] = format_value(value.get("type", "object"))
            return data if len(data) > 0 else None

    def get_swagger_request_files(self, value: str | dict) -> dict | None:
        """
        获取请求 files

        :param value:
        :return:
        """
        files = {}
        if self.version == 2:
            # 确保 value 是字典类型
            if isinstance(value, dict) and "parameters" in value:
                for v in value["parameters"]:
                    # 确保 v 是字典类型
                    if isinstance(v, dict):
                        if v.get("type") == "file":
                            files[v.get("name")] = format_value("object")
            return files if len(files) > 0 else None
        else:
            if not isinstance(value, dict):
                schema_data = self.get_swagger_schema_data(value)
                for k, v in schema_data["properties"].items():
                    if v.get("format") == "binary":
                        files[k] = format_value(v.get("type", "object"))
                return files if len(files) > 0 else None
